/**
 * Flattens a binary tree in place into a pre-order "linked list" that uses only the
 * {@code right} pointers; every {@code left} pointer ends up {@code null}.
 *
 * <p>The work is done in two passes that share the {@link #end} cursor:
 * <ol>
 *   <li>{@link #process} threads the nodes together in pre-order through their {@code left}
 *       pointers and clears every {@code right} pointer;</li>
 *   <li>{@link #dfs} walks that {@code left} chain and moves each link over to {@code right},
 *       clearing {@code left}.</li>
 * </ol>
 *
 * <p>Both passes are iterative, so deep or degenerate trees cannot overflow the call stack
 * (the second pass in particular would otherwise recurse once per node). Instances are not
 * thread-safe: the passes communicate through the unsynchronized mutable {@link #end} field.
 */
public class Flatten {

    /** Tail of the chain currently being built; starts at a throwaway sentinel in each pass. */
    TreeNode end;

    /**
     * Flattens the tree rooted at {@code root} in place into pre-order, linked via {@code right}.
     *
     * @param root root of the tree to flatten; {@code null} is a no-op
     */
    public void flatten(TreeNode root) {
        end = new TreeNode(); // sentinel: its left ends up pointing at root after process()
        process(root);
        end = new TreeNode(); // fresh sentinel: its right ends up pointing at root after dfs()
        dfs(root);
    }

    /**
     * Pass 1: appends the nodes of the subtree rooted at {@code node} after {@link #end} in
     * pre-order, linking them through {@code left}, and sets each visited node's {@code right}
     * to {@code null}. On return {@link #end} is the last node visited.
     *
     * <p>Uses an explicit stack, so memory is on the heap rather than the call stack.
     *
     * @param node subtree root; {@code null} does nothing
     */
    public void process(TreeNode node) {
        if (node == null) {
            return;
        }
        // ArrayDeque rejects null elements, so only non-null children are pushed.
        java.util.Deque<TreeNode> pending = new java.util.ArrayDeque<>();
        pending.push(node);
        while (!pending.isEmpty()) {
            TreeNode cur = pending.pop();
            // Capture both children before either pointer is overwritten below.
            TreeNode left = cur.left;
            TreeNode right = cur.right;
            // Safe: the previous node's original left child was captured when it was popped.
            end.left = cur;
            end = cur;
            cur.right = null;
            // Push right first so the left subtree is visited first (pre-order).
            if (right != null) {
                pending.push(right);
            }
            if (left != null) {
                pending.push(left);
            }
        }
    }

    /**
     * Pass 2: starting at {@code node}, follows the {@code left} chain (as built by
     * {@link #process}), appends each node after {@link #end} via {@code right}, and clears its
     * {@code left}. On return {@link #end} is the last node of the chain.
     *
     * <p>Iterative, so the cost in stack depth is constant regardless of chain length.
     *
     * @param node head of the left-linked chain; {@code null} does nothing
     */
    public void dfs(TreeNode node) {
        TreeNode cur = node;
        while (cur != null) {
            end.right = cur;
            end = cur;
            TreeNode next = cur.left; // read before clearing
            cur.left = null;
            cur = next;
        }
    }
}
